import os
import numpy as np
import pandas as pd
import random

from config import *
# Fixed seed: the ACG assignment later in this script shuffles row indices
# with `random.shuffle`, so seeding keeps the output reproducible.
random.seed(178006)

print('-------- GETTING CHIS UNINSURED DATA ------------------')
print()

# ------------- import cuts for categorical vars -----------------------------

# Each pickle supplies the interior cut points of one binned variable. A low
# sentinel (-1) and a high sentinel (999 / 999.9) are added around them so the
# binning loops below always have an outer edge on both sides.
age_cuts = [-1.0, *map(int, pd.read_pickle(CUTS_AGE).values), 999]
inc_cuts = [-1.0, *map(float, pd.read_pickle(CUTS_INC).values), 999.9]
acg_cuts = [-1.0, *map(float, pd.read_pickle(CUTS_ACG).values), 999.9]

# ------------- import all CHIS adult data -----------------------------------

# Stata file of each survey year, relative to `{CHIS_PATH}/{year}/`.
chis_files = {
    2014: 'chis14_adult_stata/CHIS14_adult_stata/Data/ADULT.dta',
    2015: 'CHIS15_adult_stata_dec17/CHIS15_adult_stata_dec17/Data/ADULT.dta',
    2016: 'CHIS16_adult_stata/CHIS16_adult_stata/Data/ADULT.dta',
}

# Read every year with the raw numeric codes (no categorical conversion) and
# tag each row with the survey year it came from.
dfs = []
for year in sorted(chis_files):
    yearly = pd.read_stata(f'{CHIS_PATH}/{year}/{chis_files[year]}',
                           convert_categoricals=False)
    yearly['year'] = year
    dfs.append(yearly)

# Stack the years into one frame with a fresh 0..n-1 index.
df = pd.concat(dfs, sort=False).reset_index(drop=True)

# ------------- clean data ---------------------------------------------------

# CHIS variables kept for this script.
cols = ['year',         # year
        'ur_omb',       # metropolitan indicator (based on MSA/CBSA)
        'ur_clrt2',     # rural/urban indicator (based on ZIP)
        'marit',        # married
        'famtyp_p',     # family type
        'hhsize_p1',    # household size
        'famsize2_p1',  # family size
        'povgwd_p',     # fraction of poverty level
        'ak22_p',       # annual household income
        'srage_p1',     # self-reported age
        'ins',          # currently insured
        'insany',       # insurance status within last 12 months
        'uninsany',     # uninsurance status within last 12 months
        'ab1',          # general health condition
        'ab17',         # asthma status
        'ab22',         # diabetes status
        'ab29',         # high blood pressure status
        'ab34',         # heart disease status
        'smkcur',       # smoker status
        'bmi_p',        # bmi
        'ovrwt',        # overweight or obese indicator
        'acmdnum',      # number of doctor visits in previous year
        'instype',      # insurance market
        'ak8',          # number of people at company
        'ah104',        # exchange status
        'ah109']        # spouse exchange status

# Keep only those variables and give a few of them readable names.
df = df[cols].rename(columns={
    'famtyp_p': 'family_type',
    'hhsize_p1': 'household_size',
    'famsize2_p1': 'family_size',
    'ab1': 'health_status',
    'bmi_p': 'bmi',
    'acmdnum': 'doctor_visit_count',
})

# create indicator variables: each one is 1 where the source column equals
# the listed code(s) and 0 everywhere else
df['is_married'] = df.marit.eq(1).astype('int64')
df['has_children'] = df.family_type.isin([4, 5]).astype('int64')
df['metro'] = df.ur_omb.eq(1).astype('int64')
df['urban'] = df.ur_clrt2.eq(1).astype('int64')
df['uninsured'] = df.ins.eq(2).astype('int64')

# health-related flags all test the source column against code 1
for flag, source in [('asthma_status', 'ab17'),
                     ('diabetes_status', 'ab22'),
                     ('blood_pressure_status', 'ab29'),
                     ('heart_disease_status', 'ab34'),
                     ('smoking_status', 'smkcur'),
                     ('overweight_status', 'ovrwt')]:
    df[flag] = df[source].eq(1).astype('int64')

df['on_exchange'] = df.ah104.eq(2).astype('int64')

# ------------- bin the age and income vars according to market data bins ----

# age bins: each pair of consecutive cuts defines the closed range
# [previous cut + 1, cut]; labels start at 2 (the first range gets label 2)
df['age'] = 0
for label, (prev_cut, cut) in enumerate(zip(age_cuts, age_cuts[1:]), start=2):
    in_bin = df.srage_p1.ge(prev_cut + 1) & df.srage_p1.le(cut)
    df.loc[in_bin, 'age'] = label

# 0 is the "unassigned" placeholder, so none may be left
assert (df.age != 0).all()

# income bins: each pair of consecutive cuts defines the half-open range
# (previous cut, cut]; labels start at 1
df['income'] = 0
for label, (prev_cut, cut) in enumerate(zip(inc_cuts, inc_cuts[1:]), start=1):
    in_bin = df.povgwd_p.gt(prev_cut) & df.povgwd_p.le(cut)
    df.loc[in_bin, 'income'] = label

assert (df.income != 0).all()

# ------------- make acg bins from chronic condition status vars -------------

# import chronic condition data: one CSV per year, stacked into one frame
dfs = [pd.read_csv(f'{CHRONIC_CONDITION_PATH}/CCMed{year}.csv')
       for year in range(2014, 2017)]

df_comorb = pd.concat(dfs, sort=False).reset_index(drop=True)

# clean data: copy the single-condition flags over as ints
for target, source in [('asthma_status', 'Asthma'),
                       ('diabetes_status', 'Diabetes'),
                       ('blood_pressure_status', 'HT')]:   # HT = Hypertension
    df_comorb[target] = df_comorb[source].map(int)

# heart disease is flagged when any of these conditions is set
any_heart_condition = (df_comorb.AMI     # Acute Myocardial Infarction
                       | df_comorb.AF    # Atrial Fibrillation
                       | df_comorb.HF    # Heart Failure
                       | df_comorb.ICD   # Ischemic Heart Disease
                       | df_comorb.STIA)  # Stroke / Transient Ischemic Attack
df_comorb['heart_disease_status'] = np.where(any_heart_condition, 1, 0)

covars = ['asthma_status', 'diabetes_status', 'blood_pressure_status', 'heart_disease_status']

df_comorb = df_comorb[['personkey', 'year'] + covars]

# import subscriber data (to merge on ACG scores); the default inner merge
# keeps only person-years present in both tables
df_subs = pd.read_csv(SUBSCRIBER_PATH)
df_subs = df_subs[['personkey', 'year', 'sum_concurrent_risk']]

df_comorb = df_comorb.merge(df_subs, on=['personkey', 'year'])

# acg bins: consecutive cuts define the half-open range (previous cut, cut];
# labels start at 1
df_comorb['acg'] = 0
for label, (prev_cut, cut) in enumerate(zip(acg_cuts, acg_cuts[1:]), start=1):
    risk = df_comorb.sum_concurrent_risk
    df_comorb.loc[risk.gt(prev_cut) & risk.le(cut), 'acg'] = label

# 0 is the "unassigned" placeholder, so none may be left
assert (df_comorb.acg != 0).all()

df_comorb = df_comorb[['acg'] + covars].dropna()

# Encode every chronic-condition profile as a single integer by weighting the
# four flags in `covars` order: asthma=1, diabetes=2, blood pressure=4,
# heart disease=8.
condition_weights = np.array([[1, 2, 4, 8]]).T
df_comorb['health_interaction'] = df_comorb[covars].values @ condition_weights
df['health_interaction'] = df[covars].values @ condition_weights
df['acg_quartiles'] = 0

# randomly assign ACGs to subscribers to match the distributions
# corresponding to the ACG/chronic condition observations
output = []

for i in sorted(df_comorb.health_interaction.unique()):
    # observed ACG counts among subscribers with this condition profile,
    # most frequent bin first
    dist = df_comorb[df_comorb.health_interaction == i].acg.value_counts().reset_index()
    dist.columns = ['acg', 'num']

    # CHIS rows sharing the profile (computed once and reused below)
    in_profile = df.health_interaction == i
    n = int(in_profile.sum())

    # rescale the subscriber counts to the n CHIS rows; the rounding drift is
    # absorbed by the first (most frequent) bin so the counts add up to n
    dist['num'] = (n * dist.num / dist.num.sum()).round(0).map(int)
    dist.loc[0, 'num'] -= (dist.num.sum() - n)

    assert dist.num.sum() == n

    # hand out the bins over a random permutation of the matching CHIS rows
    indices = df.index[in_profile.to_numpy()].to_list()
    random.shuffle(indices)

    j = 0
    for acg_bin, count in zip(dist.acg, dist.num):
        df.loc[indices[j: j + count], 'acg_quartiles'] = acg_bin
        j += count

    # create output table
    if n:
        dist['percent'] = (100 * dist.num / dist.num.sum()).round(0).map(int)
    else:
        # no CHIS row has this profile: avoid 0/0 -> NaN -> int() failure
        dist['percent'] = 0

    # ACG bins never observed for this profile are absent from `dist`;
    # report them as 0% instead of failing on the lookup
    percent_by_acg = dict(zip(dist.acg, dist.percent))

    # The CHIS flags are 0/1, so the profile code's bits are exactly the flag
    # values of any CHIS row with this profile; decoding them also works when
    # no such row exists.
    profile = int(i)
    output += [(profile & 1,            # asthma
                (profile >> 1) & 1,     # diabetes
                (profile >> 2) & 1,     # blood pressure
                (profile >> 3) & 1,     # heart disease
                int((df_comorb.health_interaction == i).sum()),
                n,
                percent_by_acg.get(1, 0),
                percent_by_acg.get(2, 0),
                percent_by_acg.get(3, 0),
                percent_by_acg.get(4, 0))]

output = pd.DataFrame(output, columns = ('asthma', 'diabetes', 'blood', 'heart',
                                         'n_subs', 'n_chis', 'acg_1', 'acg_2',
                                         'acg_3', 'acg_4'))

display('ACG quartile distributions by chronic condition status:')
display(output)
display()

# every CHIS row must have received a bin; this fails if some CHIS profile
# never occurs in the subscriber data
assert (df.acg_quartiles != 0).all()

# ------------- count observations for condition vars and --------------------
# ------------- condition/demographic vars, and merge ------------------------

# uninsured counts per conditional cell (supergroup) and per
# conditional x demographic cell (group)
df_supergroups = df.groupby(conditionals).uninsured.sum().reset_index()
df_groups = (df.groupby(conditionals + demographics).uninsured.sum()
               .reset_index()
               .rename(columns={'uninsured': 'n_unins_group'}))

# attach each group's supergroup total and take the group's share of it
df_groups = df_groups.merge(
    df_supergroups.rename(columns={'uninsured': 'n_unins'}),
    on=conditionals,
    how='left')
df_groups['conditional_prob'] = df_groups.n_unins_group / df_groups.n_unins

# make probs uniform if missing! (i.e. if n_unins is 0) -- currently disabled
#df_groups.conditional_prob = np.where(df_groups.conditional_prob.isna(),
#                                      1 / df_groups.groupby(conditionals).n_unins.transform('count'),
#                                      df_groups.conditional_prob)

cols = conditionals + ['n_unins'] + demographics + ['n_unins_group', 'conditional_prob']
df_groups = df_groups[cols]

# shares must add up to 1 within every supergroup
test = df_groups.groupby(conditionals).conditional_prob.sum()
assert ((test - 1).abs() <= 1e-10).all()

# ------------- gen cross data to make complete demographic list -------------

# Build the Cartesian product of the observed levels of every conditional and
# demographic variable by repeatedly merging on a constant helper key.
df_full = pd.DataFrame({'key': [1]})
for c in conditionals + demographics:
    levels = df_groups[[c]].drop_duplicates().assign(key=1)
    df_full = df_full.merge(levels, on='key')

df_full = df_full.drop(columns='key')

# bring in the observed counts/probabilities; combinations that never occur
# in the data are left as NaN
df_full = df_full.merge(df_groups, on=conditionals + demographics, how='left')

# shares must still add up to 1 within every conditional cell
test = df_full.groupby(conditionals).conditional_prob.sum()
assert ((test - 1).abs() <= 1e-10).all()

# ------------- fill in missing and zero values of conditional_prob ----------

n_missing = int(df_full.conditional_prob.isna().sum())
n_zero = int(df_full.conditional_prob.eq(0).sum())

# treat zero probabilities like missing ones so that both get imputed
df_full['conditional_prob'] = df_full.conditional_prob.mask(df_full.conditional_prob == 0)

# after the replacement the NaN count is the missing plus the former zeros
print(n_missing + n_zero, 'probabilities to impute')
display(n_missing, 'missing')
display(n_zero, 'zero')
print()

def impute_by(var, df_full):
    """Fill missing `conditional_prob` values by averaging over `var`.

    Rows are grouped by every variable in `conditionals + demographics`
    except `var`; a missing probability is replaced by the mean of the
    non-missing probabilities in its group. Rows whose whole group is
    missing stay NaN.

    Parameters
    ----------
    var : str
        Name of the variable to average over.
    df_full : pandas.DataFrame
        Frame with the conditional/demographic columns and `conditional_prob`.

    Returns
    -------
    pandas.DataFrame
        A new frame with the same columns; the input frame is not modified.
    """
    other_vars = [name for name in conditionals + demographics if name != var]

    # mean probability of each cell defined by all variables except `var`
    cell_means = (df_full
                  .groupby(other_vars)
                  .conditional_prob
                  .mean()
                  .reset_index()
                  .rename(columns={'conditional_prob': 'imputed_prob'}))

    # attach the cell mean to every row and use it only where the
    # probability is missing
    merged = df_full.merge(cell_means, on=other_vars, how='left')
    merged['conditional_prob'] = merged.conditional_prob.fillna(merged.imputed_prob)
    return merged.drop(columns='imputed_prob')

# Impute over one variable at a time, in the configured order, and stop as
# soon as no missing probability remains.
for impute_var in unins_impute_order:
    df_full = impute_by(impute_var, df_full)
    left_to_impute = int(df_full.conditional_prob.isna().sum())

    display(f'imputed by {impute_var} -', left_to_impute, 'left to impute')

    if not left_to_impute:
        break

# the configured order must be enough to fill every gap
assert not df_full.conditional_prob.isna().any()

# renormalize probabilities: divide by the total of the row's conditional
# cell so the shares add up to 1 again after imputation
cell_totals = df_full.groupby(conditionals).conditional_prob.transform('sum')
df_full['conditional_prob'] = df_full.conditional_prob / cell_totals

test = df_full.groupby(conditionals).conditional_prob.sum()
assert ((test - 1).abs() <= 1e-10).all()

df_full['hh_size_adjustment'] = df_full.is_married + 1

# make sure there aren't duplicates
assert not df_full.duplicated(subset=conditionals + demographics).any()

# save dataset
out_cols = (conditionals
            + demographics
            + ['n_unins', 'n_unins_group', 'hh_size_adjustment', 'conditional_prob'])
df_out = df_full[out_cols]
df_out.to_csv('../../data/2_chis_uninsured.csv', index=False)
